{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🌟动手实践function call\n",
    "\n",
    "接下来，我们将使用Yi-1.5-9B-Chat模型来实现一个独立的function call示例。这个示例不依赖于任何特定的框架，而是直接使用Hugging Face的transformers库来加载和使用模型。\n",
    "\n",
    "首先，让我们安装必要的依赖：\n",
    "⚠️这里请注意你的电脑显存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "!pip install transformers torch"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们开始逐步构建我们的function call实现。\n",
    "和上面一样我们也使用加减乘三个函数来做示例\n",
    "#### 步骤1：导入必要的库和定义函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import json\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# 定义可用的函数\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    return a * b\n",
    "\n",
    "def plus(a: int, b: int) -> int:\n",
    "    return a + b\n",
    "\n",
    "def minus(a: int, b: int) -> int:\n",
    "    return a - b\n",
    "# 在这里你可以添加自己需要的函数\n",
    "# 函数映射\n",
    "available_functions = {\n",
    "    \"multiply\": multiply,\n",
    "    \"plus\": plus,\n",
    "    \"minus\": minus\n",
    "}"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 步骤2：加载Yi模型和分词器(这一步我们使用transformers进行加载)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# 加载Yi模型和分词器\n",
    "model_path = \"01-ai/Yi-1.5-9B-Chat\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=\"auto\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 步骤3：实现生成响应和解析函数调用的功能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# 生成Yi的回复\n",
    "def generate_response(prompt):\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    outputs = model.generate(**inputs, temperature=0.7, top_p=0.95)\n",
    "    response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "    return response.split(\"Human:\")[0].strip()\n",
    "\n",
    "# 解析输出，提取函数信息\n",
    "def parse_function_call(response):\n",
    "    try:\n",
    "        # 尝试从回复中提取JSON格式的函数调用\n",
    "        start = response.index(\"{\")\n",
    "        end = response.rindex(\"}\") + 1\n",
    "        function_call_json = response[start:end]\n",
    "        function_call = json.loads(function_call_json)\n",
    "        return function_call\n",
    "    except (ValueError, json.JSONDecodeError):\n",
    "        return None\n",
    "\n",
    "# 执行函数调用(函数信息传递的桥梁)\n",
    "def execute_function(function_name: str, arguments: dict) -> Any:\n",
    "    if function_name in available_functions:\n",
    "        return available_functions[function_name](**arguments)\n",
    "    else:\n",
    "        raise ValueError(f\"Function {function_name} not found\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 步骤4：实现主循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# 主循环\n",
    "def main():\n",
    "    # 这个提示词模版中有函数的信息，如果你需要加其它的函数，这里也要和大模型同步\n",
    "    system_prompt = \"\"\"You are an AI assistant capable of calling functions to perform tasks. When a user asks a question that requires calling a function, respond with a JSON object containing the function name and arguments. Available functions are:\n",
    "    - multiply(a: int, b: int) -> int: Multiplies two integers\n",
    "    - plus(a: int, b: int) -> int: Adds two integers\n",
    "    - minus(a: int, b: int) -> int: Subtracts two integers\n",
    "    For example, if the user asks \"What is 5 plus 3?\", respond with:\n",
    "    {\"function\": \"plus\", \"arguments\": {\"a\": 5, \"b\": 3}}\n",
    "    If no function call is needed, respond normally.\"\"\"\n",
    "\n",
    "    conversation_history = [f\"System: {system_prompt}\"]\n",
    "\n",
    "    while True:\n",
    "        user_input = input(\"Human: \")\n",
    "        if user_input.lower() == 'exit':\n",
    "            break\n",
    "\n",
    "        # 添加用户输入到对话历史\n",
    "        conversation_history.append(f\"Human: {user_input}\")\n",
    "        \n",
    "        # 构建完整的提示\n",
    "        full_prompt = \"\\n\".join(conversation_history) + \"\\nAssistant:\"\n",
    "\n",
    "        # 获取模型响应\n",
    "        response = generate_response(full_prompt)\n",
    "        print(f\"Model response: {response}\")\n",
    "\n",
    "        # 解析可能的函数调用\n",
    "        function_call = parse_function_call(response)\n",
    "\n",
    "        if function_call:\n",
    "            function_name = function_call[\"function\"]\n",
    "            arguments = function_call[\"arguments\"]\n",
    "\n",
    "            try:\n",
    "                # 执行函数\n",
    "                result = execute_function(function_name, arguments)\n",
    "\n",
    "                # 将结果添加到对话历史\n",
    "                conversation_history.append(f\"Assistant: The result of {function_name}({arguments}) is {result}\")\n",
    "                print(f\"Assistant: The result of {function_name}({arguments}) is {result}\")\n",
    "            except Exception as e:\n",
    "                error_message = f\"An error occurred: {str(e)}\"\n",
    "                conversation_history.append(f\"Assistant: {error_message}\")\n",
    "                print(f\"Assistant: {error_message}\")\n",
    "        else:\n",
    "            # 如果没有函数调用，直接将模型的响应添加到对话历史\n",
    "            conversation_history.append(f\"Assistant: {response}\")\n",
    "            print(f\"Assistant: {response}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 运行示例\n",
    "\n",
    "要运行这个示例，你需要确保已经下载了Yi-1.5-9B-Chat模型，或者将`model_path`更改为模型的实际路径。然后，你可以直接运行这个Python脚本。\n",
    "\n",
    "使用示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "Human: What is 5 plus 3?\n",
    "Model response: Here's the function call to perform the addition:\n",
    "{\"function\": \"plus\", \"arguments\": {\"a\": 5, \"b\": 3}}\n",
    "Assistant: The result of plus({'a': 5, 'b': 3}) is 8"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
